{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch with Kubeflow Fairing \n",
    "\n",
    "In this notebook we will walk through training a character recognition model using the MNIST dataset with PyTorch. \n",
    "We will then show you how to use Kubeflow Fairing to run the same training job on both Kubeflow and Google AI Platform (formerly CMLE)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can skip this step if you have already installed the necessary dependencies.\n",
    "# %pip (unlike !pip) always installs into the environment of the running kernel.\n",
    "%pip install -U -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import subprocess\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "# For mac users you may get hit with this bug https://github.com/pytorch/pytorch/issues/20030\n",
    "# temporary solution is \"brew install libomp\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch Model Definition\n",
    "\n",
    "Set up a convolutional neural network using PyTorch!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"CNN with two convolutional and two fully connected layers.\n",
    "\n",
    "    The flatten step expects 20 x 4 x 4 = 320 features after the second pooling,\n",
    "    which is what single-channel 28x28 inputs produce.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Feature extractor: 1 -> 10 -> 20 channels, both with 5x5 kernels.\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        # Classifier head: 320 flattened features -> 50 hidden units -> 10 classes.\n",
    "        self.fc1 = nn.Linear(in_features=320, out_features=50)\n",
    "        self.fc2 = nn.Linear(in_features=50, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-class log-probabilities for a batch of images.\"\"\"\n",
    "        # First stage: convolution, 2x2 max-pooling, ReLU.\n",
    "        pooled = F.max_pool2d(self.conv1(x), 2)\n",
    "        features = F.relu(pooled)\n",
    "        # Second stage: same as above with channel dropout after the convolution.\n",
    "        pooled = F.max_pool2d(self.conv2_drop(self.conv2(features)), 2)\n",
    "        features = F.relu(pooled)\n",
    "        flat = features.view(-1, 320)\n",
    "        hidden = F.relu(self.fc1(flat))\n",
    "        # Functional dropout must be told explicitly whether we are training.\n",
    "        hidden = F.dropout(hidden, training=self.training)\n",
    "        logits = self.fc2(hidden)\n",
    "        return F.log_softmax(logits, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch Training and Test Functions\n",
    "A simple training function that batches the data set. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_loader, optimizer, epoch, log_interval):\n",
    "    \"\"\"Run one pass over `train_loader`, updating `model` with the NLL loss.\n",
    "\n",
    "    A progress line is printed every `log_interval` batches, except for the\n",
    "    very first batch. `epoch` is only used in that progress line.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for step, (inputs, labels) in enumerate(train_loader):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss = F.nll_loss(model(inputs), labels)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Nothing to report between logging intervals or on the first batch.\n",
    "        if step % log_interval != 0 or step == 0:\n",
    "            continue\n",
    "        samples_seen = step * len(inputs)\n",
    "        percent_done = 100. * step / len(train_loader)\n",
    "        print('Train Epoch: {}\\t[{}/{}\\t({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "            epoch, samples_seen, len(train_loader.dataset),\n",
    "            percent_done, batch_loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader, epoch):\n",
    "    \"\"\"Evaluate `model` on `test_loader` and print the average loss and accuracy.\n",
    "\n",
    "    Args:\n",
    "        model: module whose output is scored with the NLL loss (log-probabilities).\n",
    "        device: torch device every batch is moved to.\n",
    "        test_loader: data loader that also exposes `.dataset` for the sample count.\n",
    "        epoch: unused; kept so existing callers keep working.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            # Sum (rather than average) the batch loss so that dividing by the\n",
    "            # dataset size below gives the true mean. `reduction='sum'` is the\n",
    "            # supported spelling of the deprecated `size_average=False`.\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
    "            # The index of the max log-probability is the predicted class.\n",
    "            pred = output.max(1, keepdim=True)[1]\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_test(batch_size=64, epochs=1, log_interval=100, lr=0.01, model_dir=None, momentum=0.5,\n",
    "                   no_cuda=False, seed=1, test_batch_size=1000):\n",
    "    \"\"\"Train `Net` on MNIST, evaluate it after every epoch and optionally export the weights.\n",
    "\n",
    "    Args:\n",
    "        batch_size: mini-batch size of the training loader.\n",
    "        epochs: number of training epochs.\n",
    "        log_interval: forwarded to `train`; number of batches between progress lines.\n",
    "        lr: learning rate for SGD.\n",
    "        model_dir: if set, the state dict is saved as `torch.model` and copied\n",
    "            into this location with `gsutil cp`; if falsy nothing is exported.\n",
    "        momentum: momentum for SGD.\n",
    "        no_cuda: when True, run on the CPU even if CUDA is available.\n",
    "        seed: value passed to `torch.manual_seed`.\n",
    "        test_batch_size: mini-batch size of the test loader.\n",
    "\n",
    "    Raises:\n",
    "        subprocess.CalledProcessError: if the `gsutil cp` export fails.\n",
    "    \"\"\"\n",
    "    use_cuda = not no_cuda and torch.cuda.is_available()\n",
    "    torch.manual_seed(seed)\n",
    "    device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "    print(\"Using {} for training.\".format(device))\n",
    "\n",
    "    # Both splits share the same preprocessing: convert to a tensor, then\n",
    "    # normalize with the given mean and standard deviation.\n",
    "    mnist_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.1307,), std=(0.3081,))\n",
    "    ])\n",
    "    # A worker process and pinned memory are only requested when running on CUDA.\n",
    "    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('data', train=True, download=True, transform=mnist_transform),\n",
    "        batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "        **kwargs)\n",
    "    # The test split is read from the same 'data' root the call above downloaded into.\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('data', train=False, transform=mnist_transform),\n",
    "        batch_size=test_batch_size,\n",
    "        shuffle=True,\n",
    "        **kwargs)\n",
    "\n",
    "    model = Net().to(device)\n",
    "    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        start_time = time.time()\n",
    "        train(model, device, train_loader, optimizer, epoch, log_interval)\n",
    "        print(\"Time taken for epoch #{}: {:.2f}s\".format(epoch, time.time()-start_time))\n",
    "        test(model, device, test_loader, epoch)\n",
    "\n",
    "    if model_dir:\n",
    "        # Local import: only needed for the optional export step.\n",
    "        import tempfile\n",
    "\n",
    "        model_file_name = 'torch.model'\n",
    "        # Stage the weights in a private temporary directory instead of a fixed\n",
    "        # /tmp path: concurrent runs cannot overwrite each other's file and the\n",
    "        # staged copy is removed once the upload has finished (or failed).\n",
    "        with tempfile.TemporaryDirectory() as staging_dir:\n",
    "            staged_model_file = os.path.join(staging_dir, model_file_name)\n",
    "            torch.save(model.state_dict(), staged_model_file)\n",
    "            subprocess.check_call([\n",
    "                'gsutil', 'cp', staged_model_file,\n",
    "                os.path.join(model_dir, model_file_name)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Quick local run with the default arguments (a single epoch).\n",
    "train_and_test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fairing setup\n",
    "\n",
    "In this block we set some Docker config. Fairing will use this information to package up the `train_and_test` function into a container image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "from kubeflow import fairing\n",
    "from kubeflow.fairing import TrainJob\n",
    "\n",
    "# Registry that stores the output containers. Google Container Registry (GCR)\n",
    "# is used here, but any Docker container registry can be used instead.\n",
    "GCP_PROJECT = fairing.cloud.gcp.guess_project_name()\n",
    "DOCKER_REGISTRY = 'gcr.io/{}/fairing-job'.format(GCP_PROJECT)\n",
    "\n",
    "# Base image tag built from the major.minor.micro version of the running interpreter.\n",
    "PY_VERSION = \".\".join(map(str, sys.version_info[:3]))\n",
    "BASE_IMAGE = 'python:{}'.format(PY_VERSION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training on Kubeflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build an image from `train_and_test` and submit it as a training job to Kubeflow on GKE.\n",
    "# The backend is reached through the imported `fairing` module.\n",
    "train_job = TrainJob(train_and_test,\n",
    "                     base_docker_image=BASE_IMAGE,\n",
    "                     docker_registry=DOCKER_REGISTRY,\n",
    "                     backend=fairing.backends.KubeflowGKEBackend(),\n",
    "                     input_files=[\"requirements.txt\"])\n",
    "train_job.submit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training on Google AI Platform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Submit a 100-epoch run through Fairing's managed GCP backend.\n",
    "# The backend is reached through the imported `fairing` module.\n",
    "train_job = TrainJob(lambda: train_and_test(epochs=100),\n",
    "                     base_docker_image=BASE_IMAGE,\n",
    "                     docker_registry=DOCKER_REGISTRY,\n",
    "                     backend=fairing.backends.GCPManagedBackend(),\n",
    "                     input_files=[\"requirements.txt\"])\n",
    "train_job.submit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
